In [5]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import glob
import matplotlib.pyplot as plt
import seaborn as sns
import math
Data Preprocessing:¶
In [6]:
adata = sc.read_h5ad("/vast/palmer/pi/xiting_yan/hw568/collections_spatial_datasets/spatialDLPFC_new/adata_vis_after.h5ad")
In [7]:
image = adata.uns['spatial']['Br6522_ant']['images']['hires']
row = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_ant', 0] * adata.uns['spatial']['Br6522_ant']['scalefactors']['tissue_hires_scalef']
col = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_ant', 1] * adata.uns['spatial']['Br6522_ant']['scalefactors']['tissue_hires_scalef']
plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.scatter(row, col, color='red', s=1)
plt.show()
In [8]:
spot_diameter_fullres = adata.uns['spatial']['Br6522_ant']['scalefactors']['spot_diameter_fullres']
spot_radius_full_res = spot_diameter_fullres/2
tissue_hires_scalef = adata.uns['spatial']['Br6522_ant']['scalefactors']['tissue_hires_scalef']
spot_radius_hires = spot_radius_full_res * tissue_hires_scalef
print(f"The radius of spot in high resolution image is {spot_radius_hires:.2f} pixels.")
The radius of spot in high resolution image is 5.77 pixels.
In [9]:
sample_id = "Br6522_ant"
# Extract spot coordinates
spots_coords = adata.obsm['spatial'][adata.obs['sample_id'] == sample_id]
spot_radius = 5.77
# Count the total number of spots
num_spots = len(spots_coords)
# Display the total number of spots
print(f"Sample ID: {sample_id}")
print(f"Total Number of Spots: {num_spots}")
Sample ID: Br6522_ant Total Number of Spots: 4263
In [10]:
print(f"Image dimensions: Width = {image.shape[1]}, Height = {image.shape[0]}")
Image dimensions: Width = 1658, Height = 2000
This is clearly out of bounds, and needs to be scaled to match the pixel dimensions of the image.
In [11]:
scaling_factor_x = 1658 / np.max(spots_coords[:, 0])
scaling_factor_y = 2000 / np.max(spots_coords[:, 1])
spots_coords_scaled = spots_coords * [scaling_factor_x, scaling_factor_y]
print(f"Scaled Spots Range: Min X = {np.min(spots_coords_scaled[:, 0])}, Max X = {np.max(spots_coords_scaled[:, 0])}")
print(f"Scaled Spots Range: Min Y = {np.min(spots_coords_scaled[:, 1])}, Max Y = {np.max(spots_coords_scaled[:, 1])}")
Scaled Spots Range: Min X = 253.15847449782467, Max X = 1657.9999999999998 Scaled Spots Range: Min Y = 558.8632734122272, Max Y = 2000.0
Feature Extraction:¶
VIT part¶
In [12]:
from PIL import Image
import numpy as np
from transformers import ViTModel, ViTImageProcessor
import torch
model_name = "google/vit-base-patch16-224-in21k"
model = ViTModel.from_pretrained(model_name) # Pre-trained Vision Transformer
processor = ViTImageProcessor.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
# Extract features for selected spots
vector_representations = {}
with torch.no_grad():
for idx, (x, y) in enumerate(spots_coords_scaled):
start_x = max(0, int(x - spot_radius))
end_x = min(image.shape[1], int(x + spot_radius))
start_y = max(0, int(y - spot_radius))
end_y = min(image.shape[0], int(y + spot_radius))
if start_x >= end_x or start_y >= end_y:
print(f"Skipping spot {idx+1}: Invalid crop boundaries.")
continue
# Crop the image
cropped_image = image[start_y:end_y, start_x:end_x]
cropped_pil = Image.fromarray((cropped_image * 255).astype(np.uint8))
# Preprocess the image for ViT
inputs = processor(images=cropped_pil, return_tensors="pt", size=(224, 224))
inputs = {key: val.to(device) for key, val in inputs.items()}
# Forward pass through ViT
outputs = model(**inputs)
cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()
# Store the embedding
vector_representations[f"spot_{idx+1}"] = cls_embedding
Clustering¶
In [13]:
from sklearn.decomposition import PCA
vectors = np.array(list(vector_representations.values()))
pca = PCA(n_components=2)
pcs = pca.fit_transform(vectors)
pv1 = pca.components_[0]
pv2 = pca.components_[1]
# Store the principal components in a DataFrame
principalX = pd.DataFrame(data=pcs, columns=['PC1', 'PC2'])
principalX.head(10)
Out[13]:
| PC1 | PC2 | |
|---|---|---|
| 0 | -2.300600 | 1.023820 |
| 1 | 0.352776 | -1.670902 |
| 2 | -2.373689 | 1.021301 |
| 3 | 3.593608 | 0.635154 |
| 4 | 3.587902 | 0.634732 |
| 5 | 2.577854 | -0.348189 |
| 6 | 3.548557 | 0.559539 |
| 7 | 0.646632 | -2.937982 |
| 8 | -2.120048 | 0.328270 |
| 9 | -2.395325 | 0.715087 |
In [14]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
vectors = np.array(list(vector_representations.values()))
pca = PCA(n_components=2)
reduced_vectors = pca.fit_transform(vectors)
plt.figure(figsize=(8, 6))
plt.scatter(reduced_vectors[:, 0], reduced_vectors[:, 1], c='blue', edgecolor='k', s=60)
plt.title("PCA of Spot Embeddings", fontsize=16)
plt.xlabel("Principal Component 1", fontsize=12)
plt.ylabel("Principal Component 2", fontsize=12)
plt.grid(True)
plt.show()
There are 4 types of Manual Annotation, so we set K=4.
K Means Clustering¶
In [16]:
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=4, random_state=42)
clusters = kmeans.fit_predict(reduced_vectors)
plt.figure(figsize=(8, 6))
plt.scatter(reduced_vectors[:, 0], reduced_vectors[:, 1], c=clusters, cmap='viridis', edgecolor='k', s=10)
plt.colorbar(label="Cluster Label")
plt.title("PCA of Spot Embeddings with K-Means Clusters", fontsize=16)
plt.xlabel("Principal Component 1", fontsize=12)
plt.ylabel("Principal Component 2", fontsize=12)
plt.grid(True)
plt.show()
In [32]:
from umap import UMAP
umap = UMAP(n_components=2, random_state=42)
umap_results = umap.fit_transform(vectors) # Use the original vectors before PCA
# Visualize the UMAP results
plt.figure(figsize=(8, 6))
plt.scatter(umap_results[:, 0], umap_results[:, 1], c=clusters, cmap='viridis', edgecolor='k', s=10)
plt.colorbar(label="Cluster Label")
plt.title("UMAP of Spot Embeddings", fontsize=16)
plt.xlabel("UMAP Component 1", fontsize=12)
plt.ylabel("UMAP Component 2", fontsize=12)
plt.grid(True)
plt.show
/home/ll2276/.conda/envs/new_env/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism. warn(
Out[32]:
<function matplotlib.pyplot.show(close=None, block=None)>
In [30]:
# File path
file_path = "/vast/palmer/pi/xiting_yan/hw568/collections_spatial_datasets/spatialDLPFC_new/05-shared_utilities/nonIF/spatialLIBD_ManualAnnotation_Br6522_ant_wrinkle.csv"
data = pd.read_csv(file_path)
print(data.head(50))
annotation_counts = data['ManualAnnotation'].value_counts()
print(data.shape[0])
print(annotation_counts)
sample_id spot_name ManualAnnotation 0 Br6522_ant AAACGCTGGGCACGAC-1 Fold_1 1 Br6522_ant AAACGGGCGTACGGGT-1 Wrinkle_8 2 Br6522_ant AAACTAACGTGGCGAC-1 Shear_3 3 Br6522_ant AAAGGCTCTCGCGCCG-1 Wrinkle_5 4 Br6522_ant AAATAGGGTGCTATTG-1 Wrinkle_5 5 Br6522_ant AACATAGCGTGTATCG-1 Shear_1 6 Br6522_ant AACCATGGGATCGCTA-1 Wrinkle_1 7 Br6522_ant AACCCGAGCAGAATCG-1 Wrinkle_7 8 Br6522_ant AACCTTTAAATACGGT-1 Wrinkle_5 9 Br6522_ant AACGAAAGTCGTCCCA-1 Wrinkle_6 10 Br6522_ant AACGACCTCCTAGCCG-1 Fold_1 11 Br6522_ant AACGTCGCTGCACTTC-1 Fold_1 12 Br6522_ant AACTAGGCTTGGGTGT-1 Wrinkle_5 13 Br6522_ant AACTTGCGTTCTCGCG-1 Wrinkle_1 14 Br6522_ant AAGAAAGTTTGATGGG-1 Shear_2 15 Br6522_ant AAGCGCAGGGCTTTGA-1 Wrinkle_8 16 Br6522_ant AAGCGGCGTCATGGGT-1 Shear_3 17 Br6522_ant AAGCGTCCCTCATCGA-1 Wrinkle_5 18 Br6522_ant AAGCTATGGATTGACC-1 Shear_2 19 Br6522_ant AAGGGTTTGATTTCAG-1 Wrinkle_6 20 Br6522_ant AAGTAGAAGACCGGGT-1 Wrinkle_2 21 Br6522_ant AAGTTTATGGGCCCAA-1 Wrinkle_8 22 Br6522_ant AATACCTGATGTGAAC-1 Fold_1 23 Br6522_ant AATAGGCACGACCCTT-1 Shear_2 24 Br6522_ant AATAGTCCGTCCCGAC-1 Wrinkle_5 25 Br6522_ant AATCCCGCTCAGAGCC-1 Shear_1 26 Br6522_ant AATCGAGGTCTCAAGG-1 Shear_1 27 Br6522_ant AATCGCCTCAGCGCCA-1 Wrinkle_4 28 Br6522_ant AATCGGTATAGCCCTC-1 Fold_1 29 Br6522_ant AATCGTGAGCCGAGCA-1 Wrinkle_8 30 Br6522_ant AATGAGTTCGCATATG-1 Wrinkle_1 31 Br6522_ant AATGTTAAGACCCTGA-1 Shear_1 32 Br6522_ant AATTACGAGACCCATC-1 Wrinkle_6 33 Br6522_ant AATTATACCCAGCAAG-1 Shear_1 34 Br6522_ant AATTGCAGCAATCGAC-1 Wrinkle_8 35 Br6522_ant ACAAGTAATTGTAAGG-1 Shear_2 36 Br6522_ant ACAATGATTCTTCTAC-1 Shear_3 37 Br6522_ant ACAATTGTGTCTCTTT-1 Wrinkle_5 38 Br6522_ant ACACCCGAGAAATCCG-1 Wrinkle_2 39 Br6522_ant ACAGGCTTGCCCGACT-1 Wrinkle_5 40 Br6522_ant ACATAAGTCGTGGTGA-1 Wrinkle_6 41 Br6522_ant ACATACAATCAAGCGG-1 Shear_1 42 Br6522_ant ACATCCCGGCCATACG-1 Wrinkle_5 43 Br6522_ant ACATCGCAATATTCGG-1 Shear_2 44 Br6522_ant ACCAACCGCACTCCAC-1 Wrinkle_8 45 Br6522_ant ACCACACGGTTGATGG-1 Shear_1 46 Br6522_ant ACCAGTGCCCGGTCAA-1 Shear_3 47 Br6522_ant ACCATCGTATATGGTA-1 Wrinkle_4 48 Br6522_ant ACCCGGATGACGCATC-1 Wrinkle_4 49 Br6522_ant ACCCGGTTACACTTCC-1 Wrinkle_4 547 ManualAnnotation Wrinkle_5 90 Wrinkle_8 79 Shear_1 73 Wrinkle_4 72 Shear_2 56 Fold_1 49 Wrinkle_1 38 Wrinkle_2 26 Shear_3 24 Wrinkle_3 16 Wrinkle_6 11 Wrinkle_7 10 Wrinkle_9 3 Name: count, dtype: int64
In [24]:
from sklearn.preprocessing import LabelEncoder
# Encode categorical features
data['sample_id'] = LabelEncoder().fit_transform(data['sample_id'])
data['spot_name'] = LabelEncoder().fit_transform(data['spot_name'])
labels = LabelEncoder().fit_transform(data['ManualAnnotation'])
# Extract features for UMAP
features = data[['sample_id', 'spot_name']]
# Apply UMAP
reducer = umap.UMAP(n_components=2, random_state=42)
umap_results = reducer.fit_transform(features)
# Plot UMAP results
plt.figure(figsize=(10, 8))
scatter = plt.scatter(
umap_results[:, 0],
umap_results[:, 1],
c=labels,
cmap="Spectral",
edgecolor="k",
s=50,
)
plt.colorbar(scatter, label="Class Labels (ManualAnnotation)")
plt.title("UMAP Visualization of ManualAnnotation", fontsize=16)
plt.xlabel("UMAP Component 1", fontsize=12)
plt.ylabel("UMAP Component 2", fontsize=12)
plt.grid(True)
plt.show()
/home/ll2276/.conda/envs/new_env/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism. warn(
In [9]:
Br6522_mid_image = adata.uns['spatial']['Br6522_mid']['images']['hires']
row_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 0] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
col_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 1] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
plt.figure(figsize=(6, 6))
plt.imshow(Br6522_mid_image)
plt.scatter(row_mid, col_mid, color='red', s=1)
plt.show()
In [9]:
Br6522_mid_image = adata.uns['spatial']['Br6522_mid']['images']['hires']
row_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 0] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
col_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 1] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
plt.figure(figsize=(6, 6))
plt.imshow(Br6522_mid_image)
plt.scatter(row_mid, col_mid, color='red', s=1)
plt.show()
In [9]:
Br6522_mid_image = adata.uns['spatial']['Br6522_mid']['images']['hires']
row_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 0] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
col_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 1] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
plt.figure(figsize=(6, 6))
plt.imshow(Br6522_mid_image)
plt.scatter(row_mid, col_mid, color='red', s=1)
plt.show()
In [10]:
Br8667_post_image = adata.uns['spatial']['Br8667_post']['images']['hires']
row_post = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br8667_post', 0] * adata.uns['spatial']['Br8667_post']['scalefactors']['tissue_hires_scalef']
col_post = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br8667_post', 1] * adata.uns['spatial']['Br8667_post']['scalefactors']['tissue_hires_scalef']
plt.figure(figsize=(6, 6))
plt.imshow(Br8667_post_image)
plt.scatter(row_post, col_post, color='red', s=1)
plt.show()
In [3]:
# Get all sample IDs
sample_ids = list(adata.uns['spatial'].keys())
# Define the grid size
grid_size = 8
num_samples = len(sample_ids)
num_rows = math.ceil(num_samples / grid_size)
# Create a figure for the grid
fig, axes = plt.subplots(num_rows, grid_size, figsize=(grid_size * 2, num_rows * 2))
# Flatten the axes for easier indexing
axes = axes.flatten()
# Iterate over each sample and plot
for i, sample_id in enumerate(sample_ids):
# Access the image and spatial data
sample_data = adata.uns['spatial'][sample_id]
image = sample_data['images']['hires']
row = (
adata.obsm['spatial'][adata.obs['sample_id'] == sample_id, 0]
* sample_data['scalefactors']['tissue_hires_scalef']
)
col = (
adata.obsm['spatial'][adata.obs['sample_id'] == sample_id, 1]
* sample_data['scalefactors']['tissue_hires_scalef']
)
# Plot the image and points in the corresponding subplot
ax = axes[i]
ax.imshow(image)
ax.set_title(f"Sample: {sample_id}", fontsize=8)
ax.axis("off")
# Turn off any unused subplots
for j in range(num_samples, len(axes)):
axes[j].axis("off")
# Adjust layout and display
plt.tight_layout()
plt.show()